{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Build an app with LangGraph that uses UC functions as tools\n",
    "\n",
    "Prerequisite: Run this notebook on the Databricks platform."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install -qqq unitycatalog-langchain[databricks] langchain_openai langgraph mlflow\n",
    "%restart_python"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {},
     "inputWidgets": {},
     "nuid": "18072ae8-83e5-4880-9b89-a92ee9f67641",
     "showTitle": false,
     "tableResultSettingsMap": {},
     "title": ""
    }
   },
   "source": [
    "#### Create a DatabricksFunctionClient and set it as the default\n",
    "\n",
    "Create a client to create and execute UC functions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "fae8cd2a-6f52-4da2-8dbb-4d95d8284e28",
     "showTitle": false,
     "tableResultSettingsMap": {},
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "from unitycatalog.ai.core.base import set_uc_function_client\n",
    "from unitycatalog.ai.core.databricks import DatabricksFunctionClient\n",
    "\n",
    "client = DatabricksFunctionClient()\n",
    "\n",
    "# sets the default uc function client\n",
    "set_uc_function_client(client)\n",
    "\n",
    "# replace with your own catalog and schema\n",
    "CATALOG = \"ml\"\n",
    "SCHEMA = \"serena_test\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {},
     "inputWidgets": {},
     "nuid": "d5a6f448-21a5-4fe4-82e3-f952df401834",
     "showTitle": false,
     "tableResultSettingsMap": {},
     "title": ""
    }
   },
   "source": [
    "#### Create UC functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "9509aeb8-d4c7-41b9-af8f-473ffdd5379c",
     "showTitle": false,
     "tableResultSettingsMap": {},
     "title": ""
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "FunctionExecutionResult(error=None, format='SCALAR', value='2\\n', truncated=None)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def execute_python_code(code: str) -> str:\n",
    "    \"\"\"\n",
    "    Executes the given python code and returns its stdout.\n",
    "    Remember the code should print the final result to stdout.\n",
    "\n",
    "    Args:\n",
    "      code: Python code to execute. Remember to print the final result to stdout.\n",
    "    \"\"\"\n",
    "    import sys\n",
    "    from io import StringIO\n",
    "\n",
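    "    stdout = StringIO()\n",
    "    original_stdout = sys.stdout\n",
    "    sys.stdout = stdout\n",
    "    try:\n",
    "        exec(code)\n",
    "    finally:\n",
    "        # restore the real stdout even if the executed code raises\n",
    "        sys.stdout = original_stdout\n",
    "    return stdout.getvalue()\n",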
    "\n",
    "\n",
    "function_info = client.create_python_function(\n",
    "    func=execute_python_code, catalog=CATALOG, schema=SCHEMA, replace=True\n",
    ")\n",
    "python_execution_function_name = function_info.full_name\n",
    "\n",
    "# test execution\n",
    "client.execute_function(python_execution_function_name, {\"code\": \"print(1+1)\"})"
   ]
  },
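  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The stdout-capture pattern inside `execute_python_code` can be tried locally, without Unity Catalog. The cell below is a minimal sketch under that assumption: `capture_stdout` is a hypothetical helper (not part of any library) built on `contextlib.redirect_stdout`, which restores `sys.stdout` when the block exits, even if the executed code raises."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from contextlib import redirect_stdout\n",
    "from io import StringIO\n",
    "\n",
    "\n",
    "def capture_stdout(code: str) -> str:\n",
    "    \"\"\"Hypothetical local helper: run the code and return what it printed.\"\"\"\n",
    "    buffer = StringIO()\n",
    "    # redirect_stdout puts sys.stdout back when the block exits, even on error\n",
    "    with redirect_stdout(buffer):\n",
    "        exec(code, {})  # run in a fresh, empty namespace\n",
    "    return buffer.getvalue()\n",
    "\n",
    "\n",
    "capture_stdout(\"print(1 + 1)\")"
   ]
  },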
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "6a0c6a93-192f-46c8-bba3-ec81748207ec",
     "showTitle": false,
     "tableResultSettingsMap": {},
     "title": ""
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
        "'MLflow is an open-source platform for managing the end-to-end machine learning (ML) lifecycle. It was developed by Databricks and is designed to help data scientists and machine learning engineers manage the complexities of building, deploying, and maintaining ML models.\\n\\nMLflow provides a set of tools and APIs that enable users to:\\n\\n1. **Track experiments**: Record and manage the parameters, metrics, and artifacts of ML experiments, making it easier to reproduce and compare results.\\n2. **Manage models**: Store, version, and deploy ML models, including support for multiple frameworks such as TensorFlow, PyTorch, and Scikit-learn.\\n3. **Deploy models**: Deploy ML models to various environments, including cloud, on-premises, and edge devices.\\n4. **Monitor and manage models**: Track the performance of deployed models, detect data drift, and retrain models as needed.\\n\\nMLflow consists of four main components:\\n\\n1. **MLflow Tracking**: Records and manages the parameters, metrics, and artifacts of ML experiments.\\n2. **MLflow Projects**: Provides a standardized way to package and deploy ML models.\\n3. **MLflow Models**: Manages the deployment and serving of ML models.\\n4. **MLflow Registry**: Provides a centralized repository for storing and managing ML models.\\n\\nMLflow supports a wide range of ML frameworks and libraries, including:\\n\\n* TensorFlow\\n* PyTorch\\n* Scikit-learn\\n* XGBoost\\n* LightGBM\\n* Keras\\n\\nMLflow is widely used in the industry and has been adopted by many organizations, including those in finance, healthcare, and technology.\\n\\nSome of the benefits of using MLflow include:\\n\\n* **Improved collaboration**: MLflow enables data scientists and engineers to collaborate more effectively by providing a standardized way to manage ML experiments and models.\\n* **Increased productivity**: MLflow automates many tasks, such as tracking experiments and deploying models, freeing up users to focus on more strategic tasks.\\n* **Better model management**: MLflow provides a centralized repository for managing ML models, making it easier to track and manage model performance over time.\\n\\nOverall, MLflow is a powerful tool for managing the ML lifecycle, and its adoption is growing rapidly in the industry.'"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ask_ai_function_name = f\"{CATALOG}.{SCHEMA}.ask_ai\"\n",
    "sql_body = f\"\"\"CREATE OR REPLACE FUNCTION {ask_ai_function_name}(question STRING COMMENT 'question to ask')\n",
    "RETURNS STRING\n",
    "COMMENT 'answer the question using Meta-Llama-3.1-70B-Instruct model'\n",
    "RETURN SELECT ai_gen(question)\n",
    "\"\"\"\n",
    "client.create_function(sql_function_body=sql_body)\n",
    "result = client.execute_function(ask_ai_function_name, {\"question\": \"What is MLflow?\"})\n",
    "result.value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "9ddd7433-13f4-4cd4-b48f-d9a2681489a4",
     "showTitle": false,
     "tableResultSettingsMap": {},
     "title": ""
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "FunctionExecutionResult(error=None, format='SCALAR', value='MLflow: End-to-End Machine Learning Lifecycle Management', truncated=None)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "summarization_function_name = f\"{CATALOG}.{SCHEMA}.summarize\"\n",
    "sql_body = f\"\"\"CREATE OR REPLACE FUNCTION {summarization_function_name}(text STRING COMMENT 'content to parse', max_words INT COMMENT 'max number of words in the response, must be non-negative integer, if set to 0 then no limit')\n",
    "RETURNS STRING\n",
    "COMMENT 'summarize the content and limit response to max_words'\n",
    "RETURN SELECT ai_summarize(text, max_words)\n",
    "\"\"\"\n",
    "client.create_function(sql_function_body=sql_body)\n",
    "# test execution\n",
    "client.execute_function(summarization_function_name, {\"text\": result.value, \"max_words\": 20})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "05af21a7-1c09-4197-b4cb-5c80602183e0",
     "showTitle": false,
     "tableResultSettingsMap": {},
     "title": ""
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "FunctionExecutionResult(error=None, format='SCALAR', value='Hola', truncated=None)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "translate_function_name = f\"{CATALOG}.{SCHEMA}.translate\"\n",
    "sql_body = f\"\"\"CREATE OR REPLACE FUNCTION {translate_function_name}(content STRING COMMENT 'content to translate', language STRING COMMENT 'target language')\n",
    "RETURNS STRING\n",
    "COMMENT 'translate the content to target language, currently only english <-> spanish translation is supported'\n",
    "RETURN SELECT ai_translate(content, language)\n",
    "\"\"\"\n",
    "client.create_function(sql_function_body=sql_body)\n",
    "# test execution\n",
    "client.execute_function(translate_function_name, {\"content\": \"Hello\", \"language\": \"es\"})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {},
     "inputWidgets": {},
     "nuid": "5e61bb13-e40d-4d22-a0aa-f58367adba24",
     "showTitle": false,
     "tableResultSettingsMap": {},
     "title": ""
    }
   },
   "source": [
    "#### Create tools"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "a383e331-7cd5-43a0-9728-c3ece9f2670c",
     "showTitle": false,
     "tableResultSettingsMap": {},
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "python_execution_function_name = f\"{CATALOG}.{SCHEMA}.execute_python_code\"\n",
    "ask_ai_function_name = f\"{CATALOG}.{SCHEMA}.ask_ai\"\n",
    "summarization_function_name = f\"{CATALOG}.{SCHEMA}.summarize\"\n",
    "translate_function_name = f\"{CATALOG}.{SCHEMA}.translate\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "952ed06f-89c6-4495-81d6-ca1fe2a24c9b",
     "showTitle": false,
     "tableResultSettingsMap": {},
     "title": ""
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[UnityCatalogTool(name='ml__serena_test__execute_python_code', description='Executes the given python code and returns its stdout. Remember the code should print the final result to stdout.', args_schema=<class 'unitycatalog.ai.core.utils.function_processing_utils.ml__serena_test__execute_python_code__params'>, func=<function UCFunctionToolkit.uc_function_to_langchain_tool.<locals>.func at 0x7f566fe069e0>, uc_function_name='ml.serena_test.execute_python_code', client_config={'warehouse_id': None, 'profile': None}),\n",
       " UnityCatalogTool(name='ml__serena_test__ask_ai', description='answer the question using Meta-Llama-3.1-70B-Instruct model', args_schema=<class 'unitycatalog.ai.core.utils.function_processing_utils.ml__serena_test__ask_ai__params'>, func=<function UCFunctionToolkit.uc_function_to_langchain_tool.<locals>.func at 0x7f566fe07520>, uc_function_name='ml.serena_test.ask_ai', client_config={'warehouse_id': None, 'profile': None}),\n",
       " UnityCatalogTool(name='ml__serena_test__summarize', description='summarize the content and limit response to max_words', args_schema=<class 'unitycatalog.ai.core.utils.function_processing_utils.ml__serena_test__summarize__params'>, func=<function UCFunctionToolkit.uc_function_to_langchain_tool.<locals>.func at 0x7f566fe07a30>, uc_function_name='ml.serena_test.summarize', client_config={'warehouse_id': None, 'profile': None}),\n",
       " UnityCatalogTool(name='ml__serena_test__translate', description='translate the content to target language, currently only english <-> spanish translation is supported', args_schema=<class 'unitycatalog.ai.core.utils.function_processing_utils.ml__serena_test__translate__params'>, func=<function UCFunctionToolkit.uc_function_to_langchain_tool.<locals>.func at 0x7f566fd2c0d0>, uc_function_name='ml.serena_test.translate', client_config={'warehouse_id': None, 'profile': None})]"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from unitycatalog.ai.langchain.toolkit import UCFunctionToolkit\n",
    "\n",
    "toolkit = UCFunctionToolkit(\n",
    "    function_names=[\n",
    "        python_execution_function_name,\n",
    "        ask_ai_function_name,\n",
    "        summarization_function_name,\n",
    "        translate_function_name,\n",
    "    ]\n",
    ")\n",
    "tools = toolkit.tools\n",
    "tools"
   ]
  },
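  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the output above, each tool name is the fully qualified UC function name with the dots replaced by double underscores (for example, `ml.serena_test.translate` becomes `ml__serena_test__translate`). These are the names that appear in the model's tool calls and in traces. The cell below is a minimal sketch of that mapping as observed here; `to_tool_name` is a hypothetical helper, not the toolkit's actual implementation, which may apply further rules."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def to_tool_name(uc_function_name: str) -> str:\n",
    "    \"\"\"Hypothetical helper mirroring the tool naming seen in the output above.\"\"\"\n",
    "    return uc_function_name.replace(\".\", \"__\")\n",
    "\n",
    "\n",
    "to_tool_name(\"ml.serena_test.execute_python_code\")"
   ]
  },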
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {},
     "inputWidgets": {},
     "nuid": "aafb2a37-27d6-42b9-8220-b71b5ab062ea",
     "showTitle": false,
     "tableResultSettingsMap": {},
     "title": ""
    }
   },
   "source": [
    "#### Use the tools in LangGraph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "d3d6a36b-42c4-480f-a446-752a11ca2bd0",
     "showTitle": false,
     "tableResultSettingsMap": {},
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "os.environ[\"OPENAI_API_KEY\"] = \"<REPLACE WITH YOUR OWN OPENAI_API_KEY>\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "f765a706-501d-4f02-abd2-fa96987b869c",
     "showTitle": false,
     "tableResultSettingsMap": {},
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "import mlflow\n",
    "\n",
    "# Enable MLflow autologging for LangChain, which turns on tracing.\n",
    "# The generated traces help debug the agent and verify that it calls tools as expected.\n",
    "mlflow.langchain.autolog()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "eb3d0cb8-aafc-4023-a423-dca9119ff314",
     "showTitle": false,
     "tableResultSettingsMap": {},
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "from typing import Literal\n",
    "\n",
    "from langchain_openai.chat_models import ChatOpenAI\n",
    "from langgraph.graph import END, START, MessagesState, StateGraph\n",
    "from langgraph.prebuilt import ToolNode\n",
    "\n",
    "tool_node = ToolNode(tools)\n",
    "model = ChatOpenAI(model=\"gpt-4o-mini\").bind_tools(tools)\n",
    "\n",
    "\n",
    "# Define the function that determines whether to continue or not\n",
    "def should_continue(state: MessagesState) -> Literal[\"tools\", END]:\n",
    "    messages = state[\"messages\"]\n",
    "    last_message = messages[-1]\n",
    "    # If the LLM makes a tool call, then we route to the \"tools\" node\n",
    "    if last_message.tool_calls:\n",
    "        return \"tools\"\n",
    "    # Otherwise, we stop (reply to the user)\n",
    "    return END\n",
    "\n",
    "\n",
    "# Define the function that calls the model\n",
    "def call_model(state: MessagesState):\n",
    "    messages = state[\"messages\"]\n",
    "    response = model.invoke(messages)\n",
    "    # We return a list, because this will get added to the existing list\n",
    "    return {\"messages\": [response]}\n",
    "\n",
    "\n",
    "# Define a new graph\n",
    "workflow = StateGraph(MessagesState)\n",
    "\n",
    "# Define the two nodes we will cycle between\n",
    "workflow.add_node(\"agent\", call_model)\n",
    "workflow.add_node(\"tools\", tool_node)\n",
    "\n",
    "# Set the entrypoint as `agent`\n",
    "# This means that this node is the first one called\n",
    "workflow.add_edge(START, \"agent\")\n",
    "\n",
    "# We now add a conditional edge\n",
    "workflow.add_conditional_edges(\n",
    "    # First, we define the start node. We use `agent`.\n",
    "    # This means these are the edges taken after the `agent` node is called.\n",
    "    \"agent\",\n",
    "    # Next, we pass in the function that will determine which node is called next.\n",
    "    should_continue,\n",
    ")\n",
    "\n",
    "# We now add a normal edge from `tools` to `agent`.\n",
    "# This means that after `tools` is called, `agent` node is called next.\n",
    "workflow.add_edge(\"tools\", \"agent\")\n",
    "\n",
    "app = workflow.compile()"
   ]
  },
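  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The graph above encodes a simple loop: call the model, run any requested tools, and repeat until the model replies without tool calls. The cell below sketches that control flow in plain Python. It is an illustration only, not how LangGraph is implemented: `fake_model`, `fake_tools`, `run_agent_loop`, and the `power` tool are hypothetical stand-ins, and messages are plain dicts rather than LangChain message objects."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fake_model(messages):\n",
    "    \"\"\"Stand-in for the LLM: requests one tool call, then answers.\"\"\"\n",
    "    last = messages[-1]\n",
    "    if last[\"role\"] == \"user\":\n",
    "        call = {\"name\": \"power\", \"args\": {\"base\": 2, \"exp\": 10}}\n",
    "        return {\"role\": \"assistant\", \"content\": \"\", \"tool_calls\": [call]}\n",
    "    answer = f\"The result is {last['content']}.\"\n",
    "    return {\"role\": \"assistant\", \"content\": answer, \"tool_calls\": []}\n",
    "\n",
    "\n",
    "def fake_tools(tool_calls):\n",
    "    \"\"\"Stand-in for the tool node: runs every requested call.\"\"\"\n",
    "    return [\n",
    "        {\"role\": \"tool\", \"content\": str(call[\"args\"][\"base\"] ** call[\"args\"][\"exp\"])}\n",
    "        for call in tool_calls\n",
    "    ]\n",
    "\n",
    "\n",
    "def run_agent_loop(messages, max_steps=5):\n",
    "    for _ in range(max_steps):\n",
    "        reply = fake_model(messages)  # the \"agent\" node\n",
    "        messages = messages + [reply]\n",
    "        if not reply[\"tool_calls\"]:  # should_continue routes to END\n",
    "            return messages\n",
    "        # the \"tools\" node, followed by the edge back to \"agent\"\n",
    "        messages = messages + fake_tools(reply[\"tool_calls\"])\n",
    "    raise RuntimeError(\"agent did not finish within max_steps\")\n",
    "\n",
    "\n",
    "run_agent_loop([{\"role\": \"user\", \"content\": \"What is 2**10?\"}])[-1][\"content\"]"
   ]
  },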
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "e74fdf7b-bf4c-4ea1-8909-372fa292700e",
     "showTitle": false,
     "tableResultSettingsMap": {},
     "title": ""
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'MLflow es una plataforma de código abierto para gestionar el ciclo de vida completo del aprendizaje automático. Fue desarrollado por Databricks y está diseñado para ayudar a los científicos de datos y a los ingenieros a manejar las complejidades de la creación, implementación y mantenimiento de modelos de aprendizaje automático.\\n\\nMLflow proporciona herramientas y API que permiten a los científicos de datos:\\n\\n1. **Seguimiento de experimentos**: Realizar un seguimiento y gestionar experimentos, incluida la optimización de hiperparámetros.\\n2. **Gestión de modelos**: Almacenar, versionar e implementar modelos.\\n3. **Implementación de modelos**: Implementar modelos en varios entornos, incluidos la nube y locales.\\n4. **Monitoreo y análisis**: Monitorear y analizar el rendimiento de los modelos.\\n\\nSus principales componentes son:\\n\\n1. **Seguimiento de MLflow**: Para rastrear experimentos.\\n2. **Proyectos de MLflow**: Empaquetar código en proyectos reutilizables.\\n3. **Modelos de MLflow**: Registro de modelos.\\n4. **Servicio de MLflow**: Implementar modelos en diferentes entornos.\\n\\nLos beneficios incluyen mejora en la colaboración, aumento de la reproducibilidad, implementación más rápida y mejor gestión de modelos. MLflow es una herramienta poderosa para gestionar el ciclo de vida del aprendizaje automático.'"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "application/databricks.mlflow.trace": "\"tr-b9ddcdb3bc42480294709255dcea80c7\"",
      "text/plain": [
       "Trace(request_id=tr-b9ddcdb3bc42480294709255dcea80c7)"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "final_state = app.invoke(\n",
    "    {\n",
    "        \"messages\": [\n",
    "            {\n",
    "                \"role\": \"user\",\n",
    "                \"content\": \"What is MLflow? Keep the response concise and reply in Spanish. Try using as many tools as possible\",\n",
    "            }\n",
    "        ]\n",
    "    },\n",
    ")\n",
    "response = final_state[\"messages\"][-1].content\n",
    "response"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "1cc92803-9677-48e5-b7ed-04dd715d58a6",
     "showTitle": false,
     "tableResultSettingsMap": {},
     "title": ""
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Here is the translated explanation in English:\\n\\nMLflow is an open-source platform for managing the end-to-end machine learning lifecycle. It was developed by Databricks and is designed to help data scientists and engineers manage the complexities of creating, deploying, and maintaining machine learning models.\\n\\nMLflow provides tools and APIs that allow data scientists to:\\n\\n1. **Experiment tracking**: Track and manage experiments, including hyperparameter optimization.\\n2. **Model management**: Store, version, and deploy models.\\n3. **Model deployment**: Deploy models in various environments, including cloud and on-premises.\\n4. **Monitoring and analysis**: Monitor and analyze model performance.\\n\\nIts main components are:\\n\\n1. **MLflow tracking**: To track experiments.\\n2. **MLflow projects**: Package code into reusable projects.\\n3. **MLflow models**: Register models.\\n4. **MLflow model serving**: Deploy models in different environments.\\n\\nThe benefits include improved collaboration, increased reproducibility, faster deployment, and better model management. MLflow is a powerful tool for managing the machine learning lifecycle.'"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "application/databricks.mlflow.trace": "\"tr-efebb1ba511e49ee90f2085e4b54a9e6\"",
      "text/plain": [
       "Trace(request_id=tr-efebb1ba511e49ee90f2085e4b54a9e6)"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "final_state = app.invoke(\n",
    "    {\n",
    "        \"messages\": [\n",
    "            {\n",
    "                \"role\": \"user\",\n",
    "                \"content\": f\"Remember to always try using tools. Can you convert the following explanation to English? {response}\",\n",
    "            }\n",
    "        ]\n",
    "    },\n",
    ")\n",
    "final_state[\"messages\"][-1].content"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "141e202a-a8b4-4e20-aea5-19144ce93162",
     "showTitle": false,
     "tableResultSettingsMap": {},
     "title": ""
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'The result of \\\\(2^{10}\\\\) is 1024.'"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "application/databricks.mlflow.trace": "\"tr-52b8a7439b2a4344ae00868cf6af1e8a\"",
      "text/plain": [
       "Trace(request_id=tr-52b8a7439b2a4344ae00868cf6af1e8a)"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "final_state = app.invoke(\n",
    "    {\"messages\": [{\"role\": \"user\", \"content\": \"What is 2**10?\"}]},\n",
    ")\n",
    "final_state[\"messages\"][-1].content"
   ]
  },
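  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To check which tools the agent actually used for a request, the returned state can be inspected alongside the trace. The cell below is a minimal sketch that assumes `final_state` from the previous cell is still in scope; it prints the name and arguments of every tool call recorded on the messages."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `final_state` comes from the previous cell\n",
    "for message in final_state[\"messages\"]:\n",
    "    # only AI messages carry tool calls; other message types may lack the attribute\n",
    "    for call in getattr(message, \"tool_calls\", None) or []:\n",
    "        print(call[\"name\"], call[\"args\"])"
   ]
  },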
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {},
     "inputWidgets": {},
     "nuid": "3390ecd8-726f-4219-ab85-cc4d61385324",
     "showTitle": false,
     "tableResultSettingsMap": {},
     "title": ""
    }
   },
   "source": [
    "#### Log the model using MLflow"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Save the model definition to a separate file, `app.py`, using the `%%writefile` magic command. MLflow logs the model from this file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "60108a27-c47b-4d22-9ad0-c6b2d9a4a26b",
     "showTitle": false,
     "tableResultSettingsMap": {},
     "title": ""
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting app.py\n"
     ]
    }
   ],
   "source": [
    "%%writefile app.py\n",
    "from mlflow.models import set_model\n",
    "from unitycatalog.ai.core.base import set_uc_function_client\n",
    "from unitycatalog.ai.core.databricks import DatabricksFunctionClient\n",
    "from langchain_openai.chat_models import ChatOpenAI\n",
    "from langgraph.prebuilt import ToolNode\n",
    "from langgraph.graph import END, START, StateGraph, MessagesState\n",
    "from typing import Literal\n",
    "from unitycatalog.ai.langchain.toolkit import UCFunctionToolkit\n",
    "\n",
    "client = DatabricksFunctionClient()\n",
    "\n",
    "# sets the default uc function client\n",
    "set_uc_function_client(client)\n",
    "\n",
    "# replace with your own catalog and schema\n",
    "CATALOG = \"ml\"\n",
    "SCHEMA = \"serena_test\"\n",
    "\n",
    "python_execution_function_name = f\"{CATALOG}.{SCHEMA}.execute_python_code\"\n",
    "ask_ai_function_name = f\"{CATALOG}.{SCHEMA}.ask_ai\"\n",
    "summarization_function_name = f\"{CATALOG}.{SCHEMA}.summarize\"\n",
    "translate_function_name = f\"{CATALOG}.{SCHEMA}.translate\"\n",
    "toolkit = UCFunctionToolkit(\n",
    "    function_names=[\n",
    "        python_execution_function_name,\n",
    "        ask_ai_function_name,\n",
    "        summarization_function_name,\n",
    "        translate_function_name,\n",
    "    ]\n",
    ")\n",
    "tools = toolkit.tools\n",
    "\n",
    "tool_node = ToolNode(tools)\n",
    "model = ChatOpenAI(model=\"gpt-4o-mini\").bind_tools(tools)\n",
    "\n",
    "def should_continue(state: MessagesState) -> Literal[\"tools\", END]:\n",
    "    messages = state['messages']\n",
    "    last_message = messages[-1]\n",
    "    if last_message.tool_calls:\n",
    "        return \"tools\"\n",
    "    return END\n",
    "\n",
    "def call_model(state: MessagesState):\n",
    "    messages = state['messages']\n",
    "    response = model.invoke(messages)\n",
    "    return {\"messages\": [response]}\n",
    "\n",
    "workflow = StateGraph(MessagesState)\n",
    "workflow.add_node(\"agent\", call_model)\n",
    "workflow.add_node(\"tools\", tool_node)\n",
    "workflow.add_edge(START, \"agent\")\n",
    "workflow.add_conditional_edges(\"agent\", should_continue)\n",
    "workflow.add_edge(\"tools\", 'agent')\n",
    "\n",
    "app = workflow.compile()\n",
    "set_model(app)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "100564c4-4a2e-4d1e-9919-0932613b3a53",
     "showTitle": false,
     "tableResultSettingsMap": {},
     "title": ""
    }
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e3e1fb3ea48f4cecbf02defcbe97c627",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading artifacts:   0%|          | 0/11 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024/11/06 10:59:49 INFO mlflow.models.model: Found the following environment variables used during model inference: [OPENAI_API_KEY]. Please check if you need to set them when deploying the model. To disable this message, set environment variable `MLFLOW_RECORD_ENV_VARS_IN_MODEL_LOGGING` to `false`.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a236f3cd7df74e33ab506ae0b1c24ebb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Uploading artifacts:   0%|          | 0/12 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Registered model 'ml.serena_test.app_with_tools' already exists. Creating a new version of this model...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8a8a5b155f0f4386a5b16c33c2db87c1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Uploading artifacts:   0%|          | 0/12 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Created version '1' of model 'ml.serena_test.app_with_tools'.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "🏃 View run rogue-midge-190 at: https://e2-dogfood.staging.cloud.databricks.com/ml/experiments/3916415516979169/runs/a00aabdcb6464fb09de01016a571cdfc\n",
      "🧪 View experiment at: https://e2-dogfood.staging.cloud.databricks.com/ml/experiments/3916415516979169\n"
     ]
    }
   ],
   "source": [
    "import mlflow\n",
    "from mlflow.models import infer_signature\n",
    "\n",
    "# Temporary workaround: the agent's output format is not yet supported by model signatures,\n",
    "# so the signature is inferred from a plain string output instead.\n",
    "input_example = {\"messages\": [{\"role\": \"user\", \"content\": \"What is 2**10?\"}]}\n",
    "signature = infer_signature(input_example, \"1024\")\n",
    "\n",
    "with mlflow.start_run():\n",
    "    model_info = mlflow.langchain.log_model(\n",
    "        # Pass the path to the saved model file\n",
    "        \"app.py\",\n",
    "        \"model\",\n",
    "        input_example={\"messages\": [{\"role\": \"user\", \"content\": \"What is 3**10?\"}]},\n",
    "        signature=signature,\n",
    "        pip_requirements=[\n",
    "            \"mlflow\",\n",
    "            \"unitycatalog-langchain[databricks]\",\n",
    "            \"langchain_openai\",\n",
    "            \"langgraph\",\n",
    "        ],\n",
    "        registered_model_name=\"ml.serena_test.app_with_tools\",  # Replace with your own model name\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {},
     "inputWidgets": {},
     "nuid": "62a96a45-254f-41da-8f47-a0ba83ccadf6",
     "showTitle": false,
     "tableResultSettingsMap": {},
     "title": ""
    }
   },
   "source": [
    "#### Validate the model locally prior to serving"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "934afbd8-8c2c-40a6-ba47-c864487c57cf",
     "showTitle": false,
     "tableResultSettingsMap": {},
     "title": ""
    }
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2b1049b3436348d5b9e21dce2b89099d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading artifacts:   0%|          | 0/12 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "[{'messages': [{'content': 'What is 3**10?',\n",
       "    'additional_kwargs': {},\n",
       "    'response_metadata': {},\n",
       "    'type': 'human',\n",
       "    'name': None,\n",
       "    'id': '82772f75-1123-4e1b-99d2-302199a5ef7f',\n",
       "    'example': False},\n",
       "   {'content': '',\n",
       "    'additional_kwargs': {'tool_calls': [{'id': 'call_UlCzoqHx05lOM4YfQsktqYv2',\n",
       "       'function': {'arguments': '{\"code\":\"print(3**10)\"}',\n",
       "        'name': 'ml__serena_test__execute_python_code'},\n",
       "       'type': 'function'}],\n",
       "     'refusal': None},\n",
       "    'response_metadata': {'token_usage': {'completion_tokens': 26,\n",
       "      'prompt_tokens': 286,\n",
       "      'total_tokens': 312,\n",
       "      'completion_tokens_details': {'accepted_prediction_tokens': 0,\n",
       "       'audio_tokens': 0,\n",
       "       'reasoning_tokens': 0,\n",
       "       'rejected_prediction_tokens': 0},\n",
       "      'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0}},\n",
       "     'model_name': 'gpt-4o-mini-2024-07-18',\n",
       "     'system_fingerprint': 'fp_0ba0d124f1',\n",
       "     'finish_reason': 'tool_calls',\n",
       "     'logprobs': None},\n",
       "    'type': 'ai',\n",
       "    'name': None,\n",
       "    'id': 'run-4fafb419-26a3-40e1-9532-1ab6683335ab-0',\n",
       "    'example': False,\n",
       "    'tool_calls': [{'name': 'ml__serena_test__execute_python_code',\n",
       "      'args': {'code': 'print(3**10)'},\n",
       "      'id': 'call_UlCzoqHx05lOM4YfQsktqYv2',\n",
       "      'type': 'tool_call'}],\n",
       "    'invalid_tool_calls': [],\n",
       "    'usage_metadata': {'input_tokens': 286,\n",
       "     'output_tokens': 26,\n",
       "     'total_tokens': 312,\n",
       "     'input_token_details': {'audio': 0, 'cache_read': 0},\n",
       "     'output_token_details': {'audio': 0, 'reasoning': 0}}},\n",
       "   {'content': '{\"format\": \"SCALAR\", \"value\": \"59049\\\\n\"}',\n",
       "    'additional_kwargs': {},\n",
       "    'response_metadata': {},\n",
       "    'type': 'tool',\n",
       "    'name': 'ml__serena_test__execute_python_code',\n",
       "    'id': '3f9e0517-f486-4e8a-bcbf-bc5cb3737c55',\n",
       "    'tool_call_id': 'call_UlCzoqHx05lOM4YfQsktqYv2',\n",
       "    'artifact': None,\n",
       "    'status': 'success'},\n",
       "   {'content': 'The result of \\\\( 3^{10} \\\\) is 59049.',\n",
       "    'additional_kwargs': {'refusal': None},\n",
       "    'response_metadata': {'token_usage': {'completion_tokens': 17,\n",
       "      'prompt_tokens': 342,\n",
       "      'total_tokens': 359,\n",
       "      'completion_tokens_details': {'accepted_prediction_tokens': 0,\n",
       "       'audio_tokens': 0,\n",
       "       'reasoning_tokens': 0,\n",
       "       'rejected_prediction_tokens': 0},\n",
       "      'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0}},\n",
       "     'model_name': 'gpt-4o-mini-2024-07-18',\n",
       "     'system_fingerprint': 'fp_0ba0d124f1',\n",
       "     'finish_reason': 'stop',\n",
       "     'logprobs': None},\n",
       "    'type': 'ai',\n",
       "    'name': None,\n",
       "    'id': 'run-13617be9-c59b-457c-acff-93b3a8f77327-0',\n",
       "    'example': False,\n",
       "    'tool_calls': [],\n",
       "    'invalid_tool_calls': [],\n",
       "    'usage_metadata': {'input_tokens': 342,\n",
       "     'output_tokens': 17,\n",
       "     'total_tokens': 359,\n",
       "     'input_token_details': {'audio': 0, 'cache_read': 0},\n",
       "     'output_token_details': {'audio': 0, 'reasoning': 0}}}]}]"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "application/databricks.mlflow.trace": "\"tr-b760a4d97e7d40eda469e96f0b34476e\"",
      "text/plain": [
       "Trace(request_id=tr-b760a4d97e7d40eda469e96f0b34476e)"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from mlflow.models import convert_input_example_to_serving_input, validate_serving_input\n",
    "\n",
    "serving_input = convert_input_example_to_serving_input(\n",
    "    {\"messages\": [{\"role\": \"user\", \"content\": \"What is 3**10?\"}]}\n",
    ")\n",
    "validate_serving_input(model_info.model_uri, serving_input=serving_input)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {},
     "inputWidgets": {},
     "nuid": "c03c8948-2d31-4024-811a-de8fc593d9eb",
     "showTitle": false,
     "tableResultSettingsMap": {},
     "title": ""
    }
   },
   "source": [
    "#### Deploy the model to a serving endpoint for production usage\n",
    "\n",
    "Follow [this guidance](https://docs.databricks.com/en/machine-learning/model-serving/create-manage-serving-endpoints.html#create-an-endpoint) to create a serving endpoint using the registered model.\n",
    "\n",
    "**Remember** to set the following environment variables when creating the serving endpoint:\n",
    "- DATABRICKS_HOST\n",
    "- DATABRICKS_TOKEN\n",
    "- OPENAI_API_KEY"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {},
     "inputWidgets": {},
     "nuid": "10d17a13-95be-49a8-8f2f-556eeadb6cfc",
     "showTitle": false,
     "tableResultSettingsMap": {},
     "title": ""
    }
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "application/vnd.databricks.v1+notebook": {
   "dashboards": [],
   "environmentMetadata": {
    "base_environment": "",
    "client": "1"
   },
   "language": "python",
   "notebookMetadata": {
    "mostRecentlyExecutedCommandWithImplicitDF": {
     "commandId": 3916415516979323,
     "dataframes": [
      "_sqldf"
     ]
    },
    "pythonIndentUnit": 2
   },
   "notebookName": "Langchain toolkit example",
   "widgets": {}
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
